CV Week: Итоговое задание¶

На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.

Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.

В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.

Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.

In [1]:
# # torch 2.4.1+cu124
!pip install diffusers==0.30.2 peft==0.8.2 huggingface_hub==0.23.4
Collecting diffusers==0.30.2
  Downloading diffusers-0.30.2-py3-none-any.whl.metadata (18 kB)
Collecting peft==0.8.2
  Downloading peft-0.8.2-py3-none-any.whl.metadata (25 kB)
Collecting huggingface_hub==0.23.4
  Downloading huggingface_hub-0.23.4-py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (7.0.0)
Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (3.15.1)
Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (1.26.4)
Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (2024.5.15)
Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (2.32.3)
Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (0.4.5)
Requirement already satisfied: Pillow in /opt/conda/lib/python3.10/site-packages (from diffusers==0.30.2) (10.3.0)
Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (21.3)
Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (5.9.3)
Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (6.0.2)
Requirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (2.4.0)
Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (4.46.3)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (4.66.4)
Requirement already satisfied: accelerate>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from peft==0.8.2) (1.1.1)
Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub==0.23.4) (2024.6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub==0.23.4) (4.12.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->peft==0.8.2) (3.1.2)
Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (1.13.3)
Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (3.3)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.8.2) (3.1.4)
Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.10/site-packages (from importlib-metadata->diffusers==0.30.2) (3.19.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (1.26.18)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.30.2) (2024.6.2)
Requirement already satisfied: tokenizers<0.21,>=0.20 in /opt/conda/lib/python3.10/site-packages (from transformers->peft==0.8.2) (0.20.3)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.13.0->peft==0.8.2) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.13.0->peft==0.8.2) (1.3.0)
Downloading diffusers-0.30.2-py3-none-any.whl (2.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.6/2.6 MB 29.4 MB/s eta 0:00:00a 0:00:01
Downloading peft-0.8.2-py3-none-any.whl (183 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 183.4/183.4 kB 14.9 MB/s eta 0:00:00
Downloading huggingface_hub-0.23.4-py3-none-any.whl (402 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 402.6/402.6 kB 29.2 MB/s eta 0:00:00
Installing collected packages: huggingface_hub, diffusers, peft
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.26.2
    Uninstalling huggingface-hub-0.26.2:
      Successfully uninstalled huggingface-hub-0.26.2
Successfully installed diffusers-0.30.2 huggingface_hub-0.23.4 peft-0.8.2
In [2]:
from tqdm.auto import tqdm

import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
import functools

%matplotlib inline
import matplotlib.pyplot as plt
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
In [3]:
MODEL_DIR = '/kaggle/input/cv-week-2024/pytorch/default/1/cm_unet_3/'
In [4]:
#---------------------
# Visualization utils
#---------------------

def visualize_images(images):
    assert len(images) == 4
    plt.figure(figsize=(12, 3))
    for i, image in enumerate(images):
        plt.subplot(1, 4, i+1)
        plt.imshow(image)
        plt.axis('off')

    plt.subplots_adjust(wspace=-0.01, hspace=-0.01)


#--------------
# Tensor utils
#--------------

def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

#---------------
# Dataset utils
#---------------

class COCODataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
        """
        Arguments:
            root_dir (string): Директория с картинками
            transform (callable, optional): преобразования, применимые к картинкам
        """
        self.root_dir = root_dir
        self.transform = transform
        self.extensions = (
            ".jpg",
            ".jpeg",
            ".png",
            ".ppm",
            ".bmp",
            ".pgm",
            ".tif",
            ".tiff",
            ".webp",
        )
        sample_dir = os.path.join(root_dir, subset_name)

        # Собираем пути до картинок
        self.samples = sorted(
            [
                os.path.join(sample_dir, fname)
                for fname in os.listdir(sample_dir)
                if fname[-4:] in self.extensions
            ],
            key=lambda x: x.split("/")[-1].split(".")[0],
        )
        self.samples = (
            self.samples if max_cnt is None else self.samples[:max_cnt]
        )  #

        # Собираем промпты
        self.captions = {}
        with open(
            os.path.join(root_dir, f"{subset_name}.csv"), newline="\n"
        ) as csvfile:
            spamreader = csv.reader(csvfile, delimiter=",")
            for i, row in enumerate(spamreader):
                if i == 0:
                    continue
                self.captions[row[1]] = row[2]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample_path = self.samples[idx]
        sample = Image.open(sample_path).convert("RGB")

        if self.transform:
            sample = self.transform(sample)

        return {
            "image": sample,
            "text": self.captions[os.path.basename(sample_path)],
            "idxs": idx,
        }

Stable Diffusion 1.5¶

Задание №1¶

In [15]:
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type =='cuda'
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == 'cuda'
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == 'cuda'

# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()

# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.

Промпты¶

In [16]:
validation_prompts = [
    "A sad puppy with large eyes",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    "A girl with pale blue hair and a cami tank top",
    "A lighthouse in a giant wave, origami style",
    "belle epoque, christmas, red house in the forest, photo realistic, 8k",
    "A small cactus with a happy face in the Sahara desert",
    "Green commercial building with refrigerator and refrigeration units outside",
]

SD 50 steps¶

In [17]:
generator = torch.Generator('cuda').manual_seed(1)
guidance_scale = 7.5
images = pipe(validation_prompts[0], num_images_per_prompt=4,
             num_inference_steps=50,
             guidance_scale=guidance_scale,
             generator=generator).images

visualize_images(images)
  0%|          | 0/50 [00:00<?, ?it/s]
No description has been provided for this image

SD 4 steps¶

In [18]:
generator = torch.Generator('cuda').manual_seed(1)
guidance_scale = 7.5
images = pipe(validation_prompts[0], num_images_per_prompt=4,
              num_inference_steps=4,
              guidance_scale=guidance_scale,
              generator=generator).images

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

Датасет¶

In [ ]:
!wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
!tar -xzf train2014_5k.tar.gz
In [ ]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    lambda x: 2 * x - 1,
])

dataset = COCODataset(".",
    subset_name="train2014_5k",
    transform=transform,
#     max_cnt=2500
)

assert len(dataset) == 5000 # 2500

batch_size = 8 # Рекоммендуемы размер батча на Colab

train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
In [ ]:
@torch.no_grad()
def prepare_batch(batch, pipe):
    """
    Предобработка батча картинок и текстовых промптов.
    Маппим картинки в латентное пространство VAE.
    Извлекаем эмбеды промптов с помощью текстового энкодера.

    Params:

    Return:
        latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
        prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
    """

    # Токенизируем промпты
    text_inputs = pipe.tokenizer(
        batch['text'],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Извлекаем эмбеды промптов с помощью текстового энкодера
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]

    # Переводим картинки в латентное пространство VAE
    image = batch['image'].to("cuda", dtype=torch.float16)
    latents = pipe.vae.encode(image).latent_dist.sample()
    latents = latents * pipe.vae.config.scaling_factor
    return latents, prompt_embeds

UNet¶

In [19]:
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.train()
unet.to('cuda')
assert unet.dtype == torch.float32
assert unet.training

LoRA¶

In [20]:
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
    "to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out",
    "ff.net.0.proj", "ff.net.2", "conv1", "conv2", "conv_shortcut",
    "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj"
]

lora_config = LoraConfig(
    r=64, # задает ранг у матриц A и B в LoRA.
    target_modules=lora_modules
)

CT UNet¶

In [21]:
# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")

# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()

# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)

# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss

Задание №2 (0.5 балла, сдается в контесте)¶

Реализация шага DDIM¶

In [22]:
def ddim_solver_step(model_output, x_t, t, s, scheduler):
    """
    Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
    params:
        model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
        x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
        t: torch.Tensor[B] - номер текущего шага
        s: torch.Tensor[B] - номер следующего шага
        scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
    """
    alphas = torch.sqrt(scheduler.alphas_cumprod).to(x_t.device)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x_t.device)

    sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
    alphas_s = extract_into_tensor(alphas, s, x_t.shape)

    sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
    alphas_t = extract_into_tensor(alphas, t, x_t.shape)

    # Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
    alphas_s[s == 0] = 1.0
    sigmas_s[s == 0] = 0.0

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0

    x_0 = (x_t - model_output * sigmas_t) / alphas_t # x0 оценка на шаге t
    x_s = alphas_s * x_0 + sigmas_s * model_output # Переход на шаг s
    return x_s

Реализация процесса зашумления (q sample)¶

In [23]:
def q_sample(x, t, scheduler, noise=None):
    alphas = torch.sqrt(scheduler.alphas_cumprod).to(x.device)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x.device)

    sigmas_t = extract_into_tensor(sigmas, t, x.shape)
    alphas_t = extract_into_tensor(alphas, t, x.shape)

    if noise is None:
        noise = torch.randn_like(x)

    x_t = alphas_t * x + sigmas_t * noise
    return x_t

Consistency Training¶

Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.

Задание №3¶

Задание №3.1 (0.5 балла, сдается в контесте)¶

В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?

Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:

In [24]:
def get_xs_from_xt_naive(
    x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
    scheduler,
    noise=None,
    **kwargs
):
    """
    Получение точки x_s в CT режиме, т.е., аналитически.
    """
    if x_t is None:
        return q_sample(x_0, s, scheduler, noise)

    alphas = torch.sqrt(scheduler.alphas_cumprod).to(x_0.device)
    sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(x_0.device)

    sigmas_s = extract_into_tensor(sigmas, s, x_0.shape)
    alphas_s = extract_into_tensor(alphas, s, x_0.shape)

    sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
    alphas_t = extract_into_tensor(alphas, t, x_t.shape)

    alphas_s[s == 0] = 1.0
    sigmas_s[s == 0] = 0.0

    alphas_t[t == 0] = 1.0
    sigmas_t[t == 0] = 0.0

    pred_noise = (x_t - alphas_t * x_0) / sigmas_t

    x_s = (
        alphas_s / alphas_t * (x_t - sigmas_t * pred_noise) +
        sigmas_s * pred_noise
    )

    return x_s

Задание №3.2¶

Loss¶

In [25]:
def cm_loss_template(
    latents, prompt_embeds, # батч латентов и текстовых эмбедов
    unet, scheduler,

    # Функции, которые будем постепенно менять из задания к заданию
    loss_fn: callable,
    get_boundary_timesteps: callable,
    get_xs_from_xt: callable,

    num_timesteps=1000,
    step_size=20, # Указываем с каким интервалом берем шаги s и t.
):
    # Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
    assert num_timesteps == 1000
    num_intervals = num_timesteps // step_size

    index = torch.randint(1, num_intervals, (len(latents),), device=latents.device).long() # [1, num_intervals]
    t = step_size * index - 1
    s = torch.clamp(t - step_size, min=0)

    boundary_timesteps = get_boundary_timesteps(s, num_timesteps=num_timesteps)

    # Сэмплируем x_t
    noise = torch.randn_like(latents)
    x_t = q_sample(latents, t, scheduler, noise)

    noise_pred = unet(x_t.float(), t,
        encoder_hidden_states=prompt_embeds.float(),
    ).sample

    # Получаем оценку в граничной точке для x_t
    boundary_pred = ddim_solver_step(noise_pred, x_t, t, boundary_timesteps, scheduler)

    # Получаем сэмпл x_s из x_t
    x_s = get_xs_from_xt(
        latents, x_t, t, s,
        scheduler,
        prompt_embeds=prompt_embeds,
        noise=noise,
    )

    # Предсказание "таргет моделью"
    with torch.no_grad():
        target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample

    # Получаем оценку в граничной точке для x_s
    boundary_target = ddim_solver_step(target_noise_pred, x_s, s, boundary_timesteps, scheduler)

    loss = loss_fn(boundary_pred, boundary_target)
    return loss

Boundary timestamps¶

In [26]:
def get_zero_boundary_timesteps(t, **kwargs):
    """
    Определяем шаги где будут срабатывать граничные условия.
    Для классических СM это t=0.
    """
    return torch.zeros_like(t)
In [27]:
import functools

ct_loss = functools.partial(
    cm_loss_template,
    loss_fn=mse_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_naive
)

assert cm_unet.active_adapter == 'ct'

Задание №4¶

Эффективное обучение¶

Обучающий цикл¶

Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.

Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение

Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.

Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01

In [28]:
def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):

    torch.cuda.empty_cache()

    scaler = torch.amp.GradScaler()

    for i, batch in enumerate(tqdm(train_dataloader)):
        latents, prompt_embeds = prepare_batch(batch, pipe)

        with torch.amp.autocast('cuda', dtype=torch.float16):
            loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler)
            loss = loss / num_grad_accum

        scaler.scale(loss).backward()

        if (i + 1) % num_grad_accum == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        print(f"Loss: {loss.detach().item()}")
In [29]:
# # Обновляем параметры каждые 2 шага
# num_grad_accum = 2

# train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)

# cm_unet.save_pretrained('cm_unet')

Задание №5¶

Генерация с помощью обученной консистенси модели¶

Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:

In [30]:
@torch.no_grad()
def consistency_sampling(
    pipe,
    prompt,
    num_inference_steps=4,
    generator=None,
    num_images_per_prompt=4,
    guidance_scale=1
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
    do_classifier_free_guidance = guidance_scale > 0
    prompt_embeds, null_prompt_embeds = pipe.encode_prompt(
        prompt,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        device=device
    )
    # assert prompt_embeds.dtype == null_prompt_embeds.dtype == torch.float16

    # Настраиваем параметры scheduler-a
    assert pipe.scheduler.config['timestep_spacing'] == 'trailing'
    pipe.scheduler.set_timesteps(num_inference_steps)

    # Создаем батч латентов из N(0,I)
    latents = pipe.prepare_latents(
        batch_size * num_images_per_prompt,
        pipe.vae.config['latent_channels'],  # Number of channels in the latent space
        pipe.vae.config['sample_size'],       # Use sample_size for both height and width
        pipe.vae.config['sample_size'],       # Use sample_size for both height and width
        dtype=torch.float16,
        device=device,
        generator=generator
    )

    for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
        t = torch.tensor([t] * len(latents)).to(device)
        zero_t = torch.tensor([0] * len(latents)).to(device)

        cond_noise_pred = pipe.unet(latents, t, prompt_embeds).sample

        if do_classifier_free_guidance:
            uncond_noise_pred = pipe.unet(latents, t, null_prompt_embeds).sample
            noise_pred = (
                uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
            )

        else:
            noise_pred = cond_noise_pred

        # Получаем x_0 оценку из x_t
        x_0 = ddim_solver_step(noise_pred, latents, t, zero_t, pipe.scheduler)

        if i + 1 < num_inference_steps:
            # Переход на следующий шаг
            s = pipe.scheduler.timesteps[i+1]
            s = torch.tensor([s] * len(latents)).to(device)

            latents = q_sample(x_0, s, pipe.scheduler)
        else:
            latents = x_0

        latents = latents.half()

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
    return image

Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.

Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

img

In [31]:
cm_unet = PeftModel.from_pretrained(
    model=unet,
    model_id=f'{MODEL_DIR}/ct',
    config=lora_config,
    adapter_name="ct"
)
In [32]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'ct'

generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 3

# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
    pipe,
    prompt=validation_prompts[0],
    num_inference_steps=4,
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale
)

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image

Consistency Distillation¶

Задание №6¶

Теперь давайте попробуем перейти к постановке дистилляции, где шаг из $x_t$ в $x_s$ будет делаться не аналитически, а c помощью модели учителя.

Дистиляция¶

In [33]:
@torch.no_grad()
def get_xs_from_xt_with_teacher(
    x_0, x_t, t, s, # Не все эти аргументы могут быть вам нужны
    scheduler,
    prompt_embeds,
    teacher_unet,
    guidance_scale,
    **kwargs
):
    # Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
    cond_noise_pred = teacher_unet(x_t, t, prompt_embeds)

    # Для CFG нам нужно делать предсказания в unconditional случае.
    # Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
    # Извлечем эмбеды из пустого промпта и размножить их до размера батча
    uncond_input_ids = pipe.tokenizer(
        [""], return_tensors="pt", padding="max_length", max_length=77
    ).input_ids.to("cuda")

    uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(
        *prompt_embeds.shape
    )

    # Затем прогоняем модель для пустых промптов
    uncond_noise_pred = teacher_unet(x_t, t, uncond_prompt_embeds)

    # Применяем CFG формулу и получаем итоговый предикт учителя
    noise_pred = (
        uncond_noise_pred.sample + guidance_scale * (cond_noise_pred.sample - uncond_noise_pred.sample)
    )

    # Получаем x_s из x_t
    x_s = ddim_solver_step(noise_pred, x_t, t, s, scheduler)
    return x_s


# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
    get_xs_from_xt_with_teacher,
    teacher_unet=teacher_unet,
    guidance_scale=7.5
)

Pseudo-Huber Loss¶

In [34]:
def pseudo_huber_loss(
    x: torch.Tensor,
    y: torch.Tensor,
    c=0.001
):
    diff = x - y

    squared_norm = torch.norm(diff, p=2) ** 2

    loss = torch.sqrt(squared_norm + c**2) - c

    return loss
In [47]:
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32

# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")

# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [48]:
cd_loss = functools.partial(
    cm_loss_template,
    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_zero_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher
)

assert cm_unet.active_adapter == 'cd'
In [49]:
# num_grad_accum = 2 # обновляем параметры каждые 2 шага

# train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)

# cm_unet.save_pretrained('cm_unet')
In [50]:
cm_unet = PeftModel.from_pretrained(
    model=unet,
    model_id=f'{MODEL_DIR}/cd',
    config=lora_config,
    adapter_name="cd"
)
In [51]:
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'cd'

generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 0

images = consistency_sampling(
    pipe,
    prompt=validation_prompts[0],
    num_inference_steps=4,
    generator=generator,
    num_images_per_prompt=4,
    guidance_scale=guidance_scale
)

visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
In [52]:
guidance_scale = 0

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = consistency_sampling(
        pipe,
        prompt=prompt,
        num_inference_steps=4,
        generator=generator,
        num_images_per_prompt=4,
        guidance_scale=guidance_scale
    )

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Задание №7 (0.25 балла, сдается в контесте)¶

Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.

In [53]:
def get_multi_boundary_timesteps(timesteps, num_boundaries=4, num_timesteps=1000):
    """
    For a batch of timesteps, determine the corresponding boundary points.
    params:
        timesteps: torch.Tensor(batch_size, device='cuda')
    returns:
        boundary_timesteps: torch.Tensor(batch_size, device='cuda')
    """
    boundary_timesteps = torch.zeros_like(timesteps)
    if num_boundaries < 2:
        return boundary_timesteps

    boundary_points = torch.linspace(0, num_timesteps - 1, num_boundaries + 1).long().to(timesteps.device)

    for i in range(len(boundary_points) - 1):
        boundary_timesteps[
            (boundary_points[i] <= timesteps) & (timesteps < boundary_points[i + 1])
        ] = boundary_points[i]

    boundary_timesteps[timesteps >= boundary_points[-1]] = boundary_points[-2]

    return boundary_timesteps
In [54]:
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32

cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")

# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
In [55]:
multi_cd_loss = functools.partial(
    cm_loss_template,
    loss_fn=pseudo_huber_loss,
    get_boundary_timesteps=get_multi_boundary_timesteps,
    get_xs_from_xt=get_xs_from_xt_with_teacher
)

assert cm_unet.active_adapter == 'multi-cd'
In [56]:
# num_grad_accum = 2 # обновляем параметры каждые 2 шага

# train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)

# cm_unet.save_pretrained('cm_unet')
In [57]:
cm_unet = PeftModel.from_pretrained(
    model=unet,
    #model_id=f'{MODEL_DIR}/multi-cd',
    model_id='/kaggle/input/notebooke7f6989e08/cm_unet/multi-cd',
    config=lora_config,
    adapter_name="multi-cd"
)
In [59]:
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == 'multi-cd'

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(2)

    images = pipe(
        prompt=prompt,
        num_inference_steps=4,
        generator=generator,
        num_images_per_prompt=4,
        guidance_scale=guidance_scale
    ).images

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [60]:
cm_unet.push_to_hub(
    "jmpleo/cv-week-2024",
    token="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"
)
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]
Out[60]:
CommitInfo(commit_url='https://huggingface.co/jmpleo/cv-week-2024/commit/d4a323d8efe76da7e1b727c4c1376b86ad443880', commit_message='Upload model', commit_description='', oid='d4a323d8efe76da7e1b727c4c1376b86ad443880', pr_url=None, pr_revision=None, pr_num=None)
In [61]:
loaded_cm_unet = PeftModel.from_pretrained(
    unet,
    "jmpleo/cv-week-2024",
    token="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",
    subfolder='multi-cd',
    adapter_name="multi-cd",
)
multi-cd/adapter_config.json:   0%|          | 0.00/895 [00:00<?, ?B/s]
adapter_model.safetensors:   0%|          | 0.00/135M [00:00<?, ?B/s]
In [62]:
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == 'multi-cd'

guidance_scale = 1

for prompt in validation_prompts:
    generator = torch.Generator(device="cuda").manual_seed(1)

    images = pipe(
        prompt=prompt,
        num_inference_steps=4,
        generator=generator,
        num_images_per_prompt=4,
        guidance_scale=guidance_scale
    ).images

    visualize_images(images)
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
  0%|          | 0/4 [00:00<?, ?it/s]
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

P.S. Некоторые примеры плохих генераций, которые могут возникать при выполнении задания¶

Неправильный сэмплинг¶

img img

Ошибки в обучении¶

img img

Необученная модель¶

img